from sample_with_numeral_1p import *
import json
from multiprocessing import Pool
import numpy as np
import pickle

# Size of the worker pool started by the __main__ guard. sample_all_data also
# divides every per-dataset query budget by this value, so each worker only
# samples its own share.
num_processes = 72

# Selects which branch of sample_all_data runs: True writes the pickled
# "line predictor" test data, False writes the train/valid/test JSON files.
create_line_predictor_data = True

def _load_graph(path):
    """Load a pickled graph from `path`.

    Uses `pickle.load` directly instead of `nx.read_gpickle`, which was
    removed in networkx 3.0; for an uncompressed ``.pkl`` file the two are
    equivalent.

    NOTE(security): unpickling executes arbitrary code. Only load graph
    files that were produced by a trusted pipeline.
    """
    with open(path, "rb") as handle:
        return pickle.load(handle)


def _count_edges_by_kind(graph):
    """Tally the edge attributes of `graph`, grouped by endpoint kind.

    Tuple nodes are treated as numeric value nodes and string nodes as
    entities. Each edge contributes ``len(attrs)`` to its group, where
    `attrs` is the edge's attribute dict. Edges whose endpoints are neither
    tuples nor a pair of strings are not counted.
    """
    counts = {"relation": 0, "attribute": 0, "reverse_attribute": 0, "numerical": 0}
    for u, v, attrs in graph.edges(data=True):
        u_is_value = isinstance(u, tuple)
        v_is_value = isinstance(v, tuple)
        if u_is_value and v_is_value:
            counts["numerical"] += len(attrs)
        elif u_is_value:
            counts["reverse_attribute"] += len(attrs)
        elif v_is_value:
            counts["attribute"] += len(attrs)
        elif isinstance(u, str) and isinstance(v, str):
            counts["relation"] += len(attrs)
    return counts


def _report_graph_statistics(graph, node_count_label):
    """Print the node count and the per-kind edge counts of `graph`."""
    counts = _count_edges_by_kind(graph)
    print(node_count_label, graph.number_of_nodes())
    print("#relation edges: ", counts["relation"])
    print("#attribute edges: ", counts["attribute"])
    print("#reverse attribute edges: ", counts["reverse_attribute"])
    print("#numerical edges: ", counts["numerical"])
    print("#all edges: ", sum(counts.values()))


def _print_example_edge(graph, label, u_type, v_type):
    """Print the first edge whose endpoints have the given types, if any."""
    for u, v, attrs in graph.edges(data=True):
        if isinstance(u, u_type) and isinstance(v, v_type):
            print(label, u, v, attrs)
            return


def _collect_value_statistics(graph):
    """Compute per-type value ranges and normalised standard deviations.

    Every tuple node is read as ``(value, type_id)``. Returns two dicts keyed
    by type id: ``{"min": ..., "max": ...}`` and the standard deviation of the
    min-max normalised values.

    Iterating over the collected items (rather than ``range(len(...))``)
    means type ids do not have to be the contiguous integers 0..n-1.
    """
    values_by_type = {}
    for node in graph.nodes():
        if isinstance(node, tuple):
            values_by_type.setdefault(node[1], []).append(node[0])

    value_ranges = {}
    value_stdevs = {}
    for type_id, values in values_by_type.items():
        values = np.array(values)
        lowest = np.min(values)
        highest = np.max(values)
        value_ranges[type_id] = {"min": lowest, "max": highest}
        # NOTE(review): a type with a single distinct value makes
        # highest == lowest, so this divides by zero and yields nan.
        value_stdevs[type_id] = np.std((values - lowest) / (highest - lowest))
    return value_ranges, value_stdevs


def _append_mean_value(answers):
    """Append to `answers` the mean of the first component of its 2-tuples.

    Presumably the answers are ``(value, type_id)`` nodes - confirm against
    the sampler's ``query_search_answer``.
    """
    values, _ = zip(*answers)
    answers.append(np.mean(values))


def _sample_train_query(train_sampler, pattern, query_type):
    """Sample a query on the train graph that has at least one answer.

    For non-string (numeric) answers the mean value is appended to the
    answer list.
    """
    while True:
        # sample_with_pattern returns (query, attribution) at every other
        # call site; the query alone must be searched and used as a key.
        query, _ = train_sampler.sample_with_pattern(pattern, query_type)
        answers = train_sampler.query_search_answer(query)
        if len(answers) > 0:
            break
    if not isinstance(answers[0], str):
        _append_mean_value(answers)
    return query, answers


def _sample_unseen_query(target_sampler, reference_sampler, pattern, query_type):
    """Rejection-sample a query whose answers differ between two graphs.

    The query is drawn from `target_sampler` and must have answers there.
    For the numeric one-hop type (pattern "(p,(e))", query type starting
    with "n") the reference graph must have no answer at all; for every
    other type the reference graph must have answers, but a different
    number of them.

    Returns ``(query, attribution, reference_answers, target_answers)``.
    """
    numeric_one_hop = pattern == "(p,(e))" and query_type[0] == "n"
    while True:
        query, attribution = target_sampler.sample_with_pattern(pattern, query_type)
        target_answers = target_sampler.query_search_answer(query)
        reference_answers = reference_sampler.query_search_answer(query)
        if len(target_answers) == 0:
            continue
        if numeric_one_hop:
            if len(reference_answers) == 0:
                return query, attribution, reference_answers, target_answers
        elif len(reference_answers) > 0 and len(reference_answers) != len(target_answers):
            return query, attribution, reference_answers, target_answers


def _sample_eval_query(target_sampler, reference_sampler, pattern, query_type):
    """Sample a validation/test query plus its answers on both graphs.

    For numeric answers the mean value is appended to each answer list;
    an empty reference list receives 0 instead.

    Returns ``(query, reference_answers, target_answers)``.
    """
    query, _, reference_answers, target_answers = _sample_unseen_query(
        target_sampler, reference_sampler, pattern, query_type)
    if not isinstance(target_answers[0], str):
        if len(reference_answers) == 0:
            reference_answers.append(0)
        else:
            _append_mean_value(reference_answers)
        _append_mean_value(target_answers)
    return query, reference_answers, target_answers


def _sample_task2_query(test_sampler, train_sampler, pattern, query_type,
                        value_ranges, value_stdevs):
    """Sample one task-2 ("line predictor") query with easy/hard answers.

    A query is drawn on the test graph (using the train graph as reference),
    the mean of its test answers is min-max normalised, and both samplers
    are asked for the entities satisfying it. Sampling repeats until both
    the easy set (train graph) and the hard set (test-only) are non-empty.

    Returns ``(query, attribution, easy_answer_set, hard_answer_set)``.
    """
    while True:
        query, attribution, _, test_answers = _sample_unseen_query(
            test_sampler, train_sampler, pattern, query_type)

        values, _ = zip(*test_answers)
        # NOTE(review): true division is kept from the original. A float
        # such as 2.0 hashes like the int 2, so the lookup works for even
        # attribution ids; an odd id would raise KeyError. Confirm whether
        # floor division is intended.
        type_id = attribution / 2
        lowest = value_ranges[type_id]["min"]
        highest = value_ranges[type_id]["max"]
        mean_norm = (np.mean(values) - lowest) / (highest - lowest)
        stdev = value_stdevs[type_id]

        test_entities = test_sampler.get_all_satisfied_entities(
            attribution, mean_norm, lowest, highest, stdev)
        train_entities = train_sampler.get_all_satisfied_entities(
            attribution, mean_norm, lowest, highest, stdev)

        easy_answer_set = train_entities
        hard_answer_set = test_entities - train_entities
        if len(easy_answer_set) != 0 and len(hard_answer_set) != 0:
            return query, attribution, easy_answer_set, hard_answer_set


def sample_all_data(id):
    """Sample queries for every dataset and write this worker's share to disk.

    For each dataset key in ``n_queries_train_dict_same`` the train, valid
    and test graphs are loaded from ``./<data_dir>_{train,valid,test}_with_units.pkl``,
    their statistics are printed, and queries are sampled for every entry of
    ``first_round_query_types``. Each worker samples
    ``budget // num_processes`` queries per type.

    Depending on the module flag ``create_line_predictor_data``:

    * True: task-2 data is pickled to
      ``line_predictor_less_query_split/<id>_<data_dir>_line_predictor_test_data.pkl``.
    * False: task-1 train/valid/test queries are written as JSON under
      ``sampled_data_YAGO15K/`` (this directory name is fixed regardless of
      `data_dir`).

    The output directories are not created here and must already exist.

    Args:
        id: Worker index; only used to make the output file names unique.
            (The name shadows the builtin but is kept for compatibility.)

    Returns:
        None.
    """
    n_queries_train_dict = n_queries_train_dict_same
    n_queries_valid_test_dict = n_queries_valid_test_dict_same

    # Only the numeric ("n_") query types are sampled. An earlier variant
    # used the same nine patterns under un-prefixed keys ("1p", "2p", ...),
    # and another additionally sampled these entity-answer types:
    #     "e_1p": "(p,(e))",
    #     "e_2p": "(p,(p,(e)))",
    #     "e_2i": "(i,(p,(e)),(p,(e)))",
    #     "e_3i": "(i,(p,(e)),(p,(e)),(p,(e)))",
    #     "e_ip": "(p,(i,(p,(e)),(p,(e))))",
    #     "e_pi": "(i,(p,(p,(e))),(p,(e)))",
    #     "e_2u": "(u,(p,(e)),(p,(e)))",
    #     "e_up": "(p,(u,(p,(e)),(p,(e))))",
    first_round_query_types = {
        "n_1p": "(p,(e))",
        "n_2p": "(p,(p,(e)))",
        "n_3p": "(p,(p,(p,(e))))",
        "n_i2pp": "(i,(p,(p,(e))),(p,(p,(e))))",
        "n_i3pp": "(i,(p,(p,(e))),(p,(p,(e))),(p,(p,(e))))",
        "n_u2p": "(u,(p,(e)),(p,(e)))",
        "n_u3p": "(u,(p,(e)),(p,(e)),(p,(e)))",
        "n_u2pp": "(u,(p,(p,(e))),(p,(p,(e))))",
        "n_u3pp": "(u,(p,(p,(e))),(p,(p,(e))),(p,(p,(e))))",
    }

    for data_dir in n_queries_train_dict:
        # Figures recorded in the original inline comments (train / valid / test):
        #   nodes                   27250 /   28467 /   29640
        #   relation edges         947540 / 1065982 / 1184426
        #   attribute edges         20248 /   22779 /   25311
        #   reverse attribute edges 20248 /   22779 /   25311
        #   numerical edges             0 /       0 /       0
        #   all edges              988036 / 1111540 / 1235048
        print("Load Train Graph " + data_dir)
        train_graph = _load_graph("./" + data_dir + "_train_with_units.pkl")
        _report_graph_statistics(train_graph, "#nodes: ")

        print("Load Valid Graph " + data_dir)
        valid_graph = _load_graph("./" + data_dir + "_valid_with_units.pkl")
        _report_graph_statistics(valid_graph, "number of nodes: ")

        print("Load Test Graph " + data_dir)
        test_graph = _load_graph("./" + data_dir + "_test_with_units.pkl")
        _report_graph_statistics(test_graph, "number of nodes: ")

        # Print one example edge of each kind from the test graph.
        _print_example_edge(test_graph, "example numerical edge: ", tuple, tuple)
        _print_example_edge(test_graph, "example reverse attribute edge: ", tuple, str)
        _print_example_edge(test_graph, "example attribute edge: ", str, tuple)
        _print_example_edge(test_graph, "example relation edge: ", str, str)

        # Value ranges and spreads are taken from the test graph.
        value_ranges, value_stdevs = _collect_value_statistics(test_graph)

        train_graph_sampler = GraphSamplerE34(train_graph)
        valid_graph_sampler = GraphSamplerE34(valid_graph)
        test_graph_sampler = GraphSamplerE34(test_graph)

        print("sample training queries")

        train_queries = {}

        if create_line_predictor_data:
            # Task-2 data set. Queries are stored in a dict keyed by the
            # query itself, so a query sampled twice is kept only once.
            n_query = n_queries_valid_test_dict[data_dir] // num_processes
            for query_type, sample_pattern in first_round_query_types.items():
                print("line_predictor test query_type: ", query_type)
                this_type_queries = {}
                for _ in tqdm(range(n_query)):
                    query, attribution, easy_answer_set, hard_answer_set = _sample_task2_query(
                        test_graph_sampler, train_graph_sampler,
                        sample_pattern, query_type, value_ranges, value_stdevs)
                    this_type_queries[query] = {
                        "attribution": attribution,
                        "easy_answer_set": easy_answer_set,
                        "hard_answer_set": hard_answer_set,
                    }
                train_queries[query_type + "_" + sample_pattern] = this_type_queries

            with open("line_predictor_less_query_split/" + str(id) + "_" + data_dir
                      + "_line_predictor_test_data.pkl", "wb") as file:
                pickle.dump(train_queries, file)

        else:
            # Task-1 data set: train, validation and test queries as JSON.
            n_query = n_queries_train_dict[data_dir] // num_processes
            for query_type, sample_pattern in first_round_query_types.items():
                print("train query_type: ", query_type)
                this_type_queries = {}
                for _ in tqdm(range(n_query)):
                    query, train_answers = _sample_train_query(
                        train_graph_sampler, sample_pattern, query_type)
                    this_type_queries[query] = {"train_answers": train_answers}
                train_queries[query_type + "_" + sample_pattern] = this_type_queries

            with open("sampled_data_YAGO15K/" + data_dir + "_train_queries_" + str(id)
                      + "_with_units.json", "w") as file_handle:
                json.dump(train_queries, file_handle)

            print("sample validation queries")

            n_query = n_queries_valid_test_dict[data_dir] // num_processes
            validation_queries = {}
            for query_type, sample_pattern in first_round_query_types.items():
                print("validation query_type: ", query_type)
                this_type_queries = {}
                for _ in tqdm(range(n_query)):
                    query, train_answers, valid_answers = _sample_eval_query(
                        valid_graph_sampler, train_graph_sampler, sample_pattern, query_type)
                    this_type_queries[query] = {
                        "train_answers": train_answers,
                        "valid_answers": valid_answers,
                    }
                validation_queries[query_type + "_" + sample_pattern] = this_type_queries

            with open("sampled_data_YAGO15K/" + data_dir + "_valid_queries_" + str(id)
                      + "_with_units.json", "w") as file_handle:
                json.dump(validation_queries, file_handle)

            print("sample testing queries")

            test_queries = {}
            for query_type, sample_pattern in first_round_query_types.items():
                print("test query_type: ", query_type)
                this_type_queries = {}
                for _ in tqdm(range(n_query)):
                    query, valid_answers, test_answers = _sample_eval_query(
                        test_graph_sampler, valid_graph_sampler, sample_pattern, query_type)
                    this_type_queries[query] = {
                        "valid_answers": valid_answers,
                        "test_answers": test_answers,
                    }
                test_queries[query_type + "_" + sample_pattern] = this_type_queries

            with open("sampled_data_YAGO15K/" + data_dir + "_test_queries_" + str(id)
                      + "_with_units.json", "w") as file_handle:
                json.dump(test_queries, file_handle)


if __name__ == "__main__":
    # One task per worker; the task index becomes the `id` that tags each
    # worker's output files. For single-process debugging, call
    # sample_all_data(1) directly instead.
    worker_ids = range(num_processes)
    with Pool(num_processes) as pool:
        results = pool.map(sample_all_data, worker_ids)
        print(results)
